import os
import openai
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from dotenv import load_dotenv

from utils.util import read_txt

# Populate the process environment from a .env file (python-dotenv lookup),
# then hand the OpenAI key over to the ``openai`` SDK's module-level setting.
load_dotenv()
openai.api_key = os.environ.get('OPENAI_API_KEY')

class Retrieval:
    """Retrieve IDs of protocols similar to a query through a RAG chain.

    Documents are fetched from a per-domain Chroma vector store. They are
    rendered into the ``similar_protocol_retrieval`` prompt alongside the
    query, and the LLM's comma-separated answer is parsed into a list of IDs.
    """

    def __init__(
        self,
        domain,
        *,
        k=6,
        llm_model="gpt-4o-mini",
        embedding_model="text-embedding-ada-002",
    ) -> None:
        """Load the prompt template, vector store, retriever and chat model.

        Args:
            domain: Selects the persisted store at
                ``planning/data/{domain}_sampled_vectorstore/``.
            k: Number of documents the retriever returns per query.
            llm_model: OpenAI chat model used to answer the prompt.
            embedding_model: OpenAI embedding model used for query vectors.
                It should match the model used to build the persisted store,
                otherwise similarity scores are meaningless.
        """
        self.prompt = PromptTemplate.from_template(
            read_txt("planning/data/prompt/similar_protocol_retrieval.txt")
        )
        self.vectorstore = Chroma(
            persist_directory=f"planning/data/{domain}_sampled_vectorstore/",
            embedding_function=OpenAIEmbeddings(model=embedding_model),
        )
        self.retriever = self.vectorstore.as_retriever(
            search_type="similarity", search_kwargs={"k": k}
        )
        self.llm = ChatOpenAI(model=llm_model)

    def run_query(self, query):
        """Return the list of protocol IDs the LLM selects for *query*.

        The same *query* value is sent down both branches of the input map.
        It is the retriever's search string (results fill ``context``), and
        it is passed through unchanged as the prompt's ``title`` variable.
        """
        rag_chain = (
            {
                "context": self.retriever | self.__format_docs,
                "title": RunnablePassthrough(),
            }
            | self.prompt
            | self.llm
            | StrOutputParser()
        )
        response = rag_chain.invoke(query)
        return self.__process(response)

    def __format_docs(self, docs):
        """Join the retrieved documents' text, separated by blank lines."""
        return "\n\n".join(doc.page_content for doc in docs)

    def __process(self, response):
        """Split the LLM's comma-separated answer into stripped, non-empty IDs.

        Empty fragments are dropped, such as those from an empty answer or a
        trailing or doubled comma. Without this, ``""`` would yield ``[""]``
        and ``"a, b,"`` would yield ``["a", "b", ""]``.
        """
        return [part.strip() for part in response.split(",") if part.strip()]
